import datetime
import functools
import os
import pathlib
import shutil
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import datasets.distributed
import torch
import torch.distributed._functional_collectives
import torch.distributed.checkpoint
import torch.distributed.checkpoint.stateful
from diffusers.hooks import HookRegistry, ModelHook
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.tensor import DTensor, Shard

from finetrainers._metadata import ContextParallelModelPlan, CPInput, CPOutput, TransformerRegistry
from finetrainers.data import DPDataLoader
from finetrainers.logging import get_logger
from finetrainers.utils import enable_determinism, get_device_info, get_submodule_by_name, unwrap_module
from finetrainers.utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES

from .base import BaseCheckpointer, BaseParallelBackend


if TYPE_CHECKING:
    from finetrainers import optimizer


_device_type, _device_module = get_device_info()
logger = get_logger()


class PytorchDTensorParallelBackend(BaseParallelBackend):
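    # Example usage (illustrative sketch; assumes 8 processes launched via `torchrun --nproc_per_node=8`):
    #   backend = PytorchDTensorParallelBackend(
    #       world_size=8, dp_degree=2, dp_shards=2, cp_degree=2, tp_degree=1, output_dir="outputs"
    #   )
    #   mesh = backend.get_mesh("dp_cp")  # flattened data-parallel + context-parallel sub-mesh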
    def __init__(
        self,
        world_size: int,
        pp_degree: int = 1,
        dp_degree: int = 1,
        dp_shards: int = -1,
        cp_degree: int = 1,
        tp_degree: int = 1,
        backend: str = "nccl",
        timeout: int = 180,
        logging_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        gradient_accumulation_steps: Optional[int] = None,
    ) -> None:
        super().__init__()

        self._world_size = world_size
        self._pp_degree = pp_degree
        self._dp_degree = dp_degree
        self._dp_shards = dp_shards
        self._cp_degree = cp_degree
        self._tp_degree = tp_degree
        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
        self._logging_dir = (
            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
        )
        self._backend = backend
        self._timeout = timeout

        for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
            if degree < 1:
                raise ValueError(f"Parallel degree must be at least 1, got {degree}.")

        if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
            raise ValueError(
                f"World size {world_size} must equal the product of all parallel degrees and data parallel shards "
                f"(got {dp_shards * pp_degree * dp_degree * cp_degree * tp_degree})."
            )

        torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
        _device_module.set_device(self.local_rank)

        logger.info(
            f"Initialized parallel state with:\n"
            f"  - World size: {world_size}\n"
            f"  - Pipeline parallel degree: {pp_degree}\n"
            f"  - Data parallel degree: {dp_degree}\n"
            f"  - Context parallel degree: {cp_degree}\n"
            f"  - Tensor parallel degree: {tp_degree}\n"
            f"  - Data parallel shards: {dp_shards}\n"
        )

        self._mesh: Optional[torch.distributed.DeviceMesh] = None

    def enable_determinism(self, seed):
        world_mesh = self.get_mesh()
        enable_determinism(seed, world_mesh)

    def apply_ddp(
        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
    ) -> torch.nn.Module:
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_ddp(model, device_mesh)
        logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
        return model

    def apply_fsdp2(
        self,
        model: torch.nn.Module,
        param_dtype: torch.dtype,
        reduce_dtype: torch.dtype,
        output_dtype: torch.dtype,
        pp_enabled: bool = False,
        cpu_offload: bool = False,
        device_mesh: Optional[torch.distributed.DeviceMesh] = None,
    ) -> torch.nn.Module:
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_fsdp2(model, device_mesh, param_dtype, reduce_dtype, output_dtype, pp_enabled, cpu_offload)
        logger.debug("Applied PytorchDTensorParallel::apply_fsdp2 to model.")
        return model

    def apply_context_parallel(
        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
    ) -> torch.nn.Module:
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_context_parallel(model, device_mesh)
        logger.debug("Applied PytorchDTensorParallel::apply_context_parallel to model.")
        return model

    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
        return model

    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
        if self._dp_degree == 1:
            return dataset
        dp_mesh = self.get_mesh()["dp_replicate"]
        dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
        logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
        return dataset

    def prepare_dataloader(
        self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
    ) -> DPDataLoader:
        if self._dp_degree == 1:
            dp_local_rank = 0
        else:
            dp_mesh = self.get_mesh()["dp_replicate"]
            dp_local_rank = dp_mesh.get_local_rank()
        dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
        logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
        return dataloader

    def prepare_optimizer(self, optimizer, lr_scheduler):
        logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
        return optimizer, lr_scheduler

    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
        def _get_mesh():
            if name is None:
                return self._mesh
            try:
                return self._mesh[name]
            except (KeyError, RuntimeError):
                if self._mesh.ndim == 0:
                    return None
                return self._mesh

        if self._mesh is not None:
            return _get_mesh()

        mesh_list = [
            ("pp", self._pp_degree),
            ("dp_replicate", self._dp_degree),
            ("dp_shard", self._dp_shards),
            ("cp", self._cp_degree),
            ("tp", self._tp_degree),
        ]
        mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
        names = [x[0] for x in mesh_list]
        degrees = [x[1] for x in mesh_list]
        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)

        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []

        if self.data_replication_enabled:
            dp_mesh_names.append("dp_replicate")
            dp_cp_mesh_names.append("dp_replicate")
        if self.data_sharding_enabled:
            dp_mesh_names.append("dp_shard")
            dp_cp_mesh_names.append("dp_shard")
            dp_shard_cp_mesh_names.append("dp_shard")
        if self.context_parallel_enabled:
            dp_cp_mesh_names.append("cp")
            dp_shard_cp_mesh_names.append("cp")

        if len(dp_mesh_names) > 0:
            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
        if len(dp_cp_mesh_names) > 0:
            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
        if len(dp_shard_cp_mesh_names) > 0:
            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
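
        # Example (illustrative): with dp_degree=2, dp_shards=2 and cp_degree=2 on 8 ranks, the mesh has
        # dims ("dp_replicate", "dp_shard", "cp") plus the flattened "dp", "dp_cp" and "dp_shard_cp"
        # views, which callers typically select by name via get_mesh(...).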

        logger.debug(f"Device mesh: {mesh}")
        self._mesh = mesh
        return _get_mesh()

    def get_checkpointer(self, *args, **kwargs):
        return PTDCheckpointer(*args, **kwargs)

    @property
    def world_size(self):
        return torch.distributed.get_world_size()

    @property
    def rank(self):
        return torch.distributed.get_rank()

    @property
    def local_rank(self):
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def is_main_process(self):
        r"""Returns `True` if the current process is the main process on the master node."""
        return self.rank == 0

    @property
    def is_local_main_process(self):
        r"""Returns `True` if the current process is the main process on local node."""
        return self.local_rank == 0

    @property
    def device(self):
        return torch.device(_device_type, self.local_rank)

    def wait_for_everyone(self):
        return torch.distributed.barrier()

    # @contextmanager
    # def main_process_first(self):
    #     if self.is_main_process:
    #         yield
    #         self.wait_for_everyone()
    #     else:
    #         self.wait_for_everyone()
    #         yield

    def destroy(self):
        if self.is_main_process and self.tracker is not None:
            self.tracker.finish()
        return torch.distributed.destroy_process_group()

    @property
    def pipeline_parallel_enabled(self):
        return self._pp_degree > 1

    @property
    def data_parallel_enabled(self):
        return self._dp_degree > 1 or self._dp_shards > 1

    @property
    def data_replication_enabled(self):
        return self._dp_degree > 1

    @property
    def data_sharding_enabled(self):
        return self._dp_shards > 1

    @property
    def context_parallel_enabled(self):
        return self._cp_degree > 1

    @property
    def tensor_parallel_enabled(self):
        return self._tp_degree > 1


class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful):
    def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        self.model = [model] if isinstance(model, torch.nn.Module) else model

    def state_dict(self) -> Dict[str, Any]:
        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        func = functools.partial(
            set_model_state_dict,
            model_state_dict=state_dict,
            options=StateDictOptions(strict=False),
        )
        list(map(func, self.model))


class PTDCheckpointer(BaseCheckpointer):
    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model_parts: List[torch.nn.Module],
        optimizers: "optimizer.OptimizerWrapper",
        schedulers: "optimizer.SchedulerWrapper",
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: int,
        output_dir: str,
        enable: bool = True,
        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
        _prefix: str = "finetrainers_step",
    ) -> None:
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model_parts),
                "optimizer": optimizers,
                "dataloader": dataloader,
            }
        )
        self.states.update(schedulers.get_lr_scheduler_state())

        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        state_dicts = [
            gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
            for model in self.states["model"].model
        ]
        if self._callback_fn is not None:
            list(map(self._callback_fn, state_dicts))

        return checkpoint_dir.as_posix()

    def load(self, step: int = -1) -> bool:
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = int(latest_checkpoint_dir.name.split("_")[-1])

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        # For step 0, optimizers/schedulers are not available, as they are only created after the first training step
        states = {"model": self.states["model"]} if step == 0 else self.states

        # See bug: https://github.com/pytorch/pytorch/pull/138575
        original_stateful_states = {
            k: v for k, v in states.items() if isinstance(v, torch.distributed.checkpoint.stateful.Stateful)
        }
        begin_time = time.monotonic()
        torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
        states.update(original_stateful_states)

        return True

    def _should_checkpoint(self, step: int, force: bool) -> bool:
        if not self.enable:
            return False
        if not force:
            if step % self.checkpointing_steps != 0:
                return False
        return True

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        return self.output_dir / f"{self._prefix}_{step}"

    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = sorted(
            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
        )
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)


def gather_state_dict_on_cpu_rank0(
    model, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
    cpu_state_dict = {}
    sharded_sd = model.state_dict()
    for param_name, param in sharded_sd.items():
        if param.is_cpu:
            # Move back to device if offloaded to CPU
            param = param.to(device)
        if hasattr(param, "_local_tensor"):
            # Gather DTensor
            param = param.full_tensor()
        if is_main_process:
            cpu_state_dict[param_name] = param.cpu()
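        # Every rank participates in full_tensor() and the barrier below even though only the main
        # process materializes a CPU copy, keeping the collectives in sync across ranks.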
        torch.distributed.barrier()
    return cpu_state_dict


# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
# def dcp_to_torch_save(
#     dcp_checkpoint_dir: Union[str, os.PathLike],
#     torch_save_path: Union[str, os.PathLike],
#     callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
# ):
#     """
#     Given a directory containing a DCP checkpoint, this function will convert it into a
#     Torch save file.

#     Args:
#         dcp_checkpoint_dir: Directory containing the DCP checkpoint.
#         torch_save_path: Filename to store the converted Torch save file.
#         callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.

#     .. warning::
#         To avoid OOM, it's recommended to only run this function on a single rank.
#     """
#     state_dict = {}
#     _load_state_dict(
#         state_dict,
#         storage_reader=FileSystemReader(dcp_checkpoint_dir),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     if callback_fn is not None:
#         state_dict = callback_fn(state_dict)
#     torch.save(state_dict, torch_save_path)


def apply_ddp(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)


def apply_fsdp2(
    model: torch.nn.Module,
    dp_mesh: torch.distributed.device_mesh.DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    output_dtype: torch.dtype,
    pp_enabled: bool = False,
    cpu_offload: bool = False,
) -> None:
    """Apply FSDP2 on a model."""
    mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)

    def apply_fully_shard(blocks):
        for layer_index, block in enumerate(blocks):
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = layer_index < len(blocks) - 1
            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)

    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks = getattr(model, transformer_block_name, None)
        if blocks is not None:
            apply_fully_shard(blocks)

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_context_parallel(
    model: torch.nn.Module,
    mesh: torch.distributed.device_mesh.DeviceMesh,
    plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {mesh}")
    model_cls = unwrap_module(model).__class__

    if plan is None:
        plan = TransformerRegistry.get(model_cls).cp_plan

    for module_id, cp_model_plan in plan.items():
        module = get_submodule_by_name(model, module_id)
        if not isinstance(module, list):
            module = [module]
        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(module)} modules")
        for m in module:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, list):
                # Metadata can only be a list when it is a list of CPOutput
                assert all(isinstance(x, CPOutput) for x in cp_model_plan)
                hook = ContextParallelGatherHook(cp_model_plan, mesh)
                hook_name = f"cp_output---{module_id}"
            else:
                hook = ContextParallelSplitHook(cp_model_plan, mesh)
                hook_name = f"cp_input---{module_id}"
            registry.register_hook(hook, hook_name)


class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
        super().__init__()
        self.metadata = metadata
        self.mesh = mesh

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for param_identifier, cpm in self.metadata.items():
            name = param_identifier.name
            index = param_identifier.index

            if isinstance(cpm, CPInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            is_kwarg = True
            input_val = kwargs.get(name, None)

            # If not, maybe it was passed as a positional argument
            if input_val is None and index is not None:
                if index < len(args_list):  # Ensure index is within bounds
                    input_val = args_list[index]
                    is_kwarg = False
                else:
                    logger.warning(f"Index {index} out of bounds for args of length {len(args_list)}.")
                    continue  # Skip if index is invalid

            # Either the input is truly None, or it was passed positionally but the user forgot to
            # specify the index when registering the metadata
            if input_val is None:
                continue

            # The input_val may be a tensor or a list/tuple of tensors. In certain cases, the user may
            # choose to shard the output instead of the input for a particular layer by setting split_output=True
            if torch.is_tensor(input_val):
                input_val = self._prepare_cp_input(input_val, cpm)

            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected the CP model plan to have {len(input_val)} entries (one per input element), but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val

            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = torch.is_tensor(output)
        is_tensor_list = isinstance(output, (list, tuple)) and all(torch.is_tensor(x) for x in output)
        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
        output = [output] if is_tensor else list(output)
        for param_identifier, cpm in self.metadata.items():
            if not isinstance(cpm, CPInput) or not cpm.split_output:
                continue
            index = param_identifier.index
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output
        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: CPInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return _EquipartitionSharder.shard(x, cp_input.split_dim, self.mesh)


class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
        super().__init__()
        self.metadata = metadata
        self.mesh = mesh

    def post_forward(self, module, output):
        is_tensor = torch.is_tensor(output)
        if is_tensor:
            output = [output]
        output = list(output)
        assert len(output) == len(self.metadata), f"Expected {len(self.metadata)} outputs, but got {len(output)}."
        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = _EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.mesh)
        return output[0] if is_tensor else tuple(output)


class _ContextParallelSharder:
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        raise NotImplementedError("_ContextParallelSharder::shard should be implemented in subclasses")

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        raise NotImplementedError("_ContextParallelSharder::unshard should be implemented in subclasses")


class _EquipartitionSharder(_ContextParallelSharder):
    """
    Shards the input tensor along the specified dimension into as many chunks as the CP mesh's
    world size. Essentially, rank_i gets the i-th chunk.

    This sharding strategy should only be used when performing full attention. Otherwise, it will
    incur a performance penalty. If using causal attention, use _CausalSharder instead.
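
    Example (illustrative): with a CP mesh of size 4 and a sequence of length 8 sharded on dim=1,
    rank 0 keeps tokens [0, 1], rank 1 keeps [2, 3], rank 2 keeps [4, 5] and rank 3 keeps [6, 7].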
    """

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        assert tensor.size()[dim] % mesh.size() == 0
        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
        result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
        return result


# TODO(aryan): this class is untested
class _CausalSharder(_ContextParallelSharder):
    """
    Shards the input tensor along the specified dimension into 2 * cp_world_size chunks.
    Essentially, rank_i gets the i-th chunk and the (2 * cp_world_size - 1 - i)-th chunk.

    This sharding strategy improves performance for causal attention, as it distributes the
    computation equally across all ranks.

    Causal attention mask:
    ```
    1 0 0 0    <--- Group 0
    1 1 0 0    <--- Group 1
    1 1 1 0    <--- Group 1
    1 1 1 1    <--- Group 0
    ```
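
    Example (illustrative): with a CP world size of 2, the sequence is split into 4 chunks; rank 0
    receives chunks (0, 3) and rank 1 receives chunks (1, 2), balancing the causal-attention work.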
    """

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        world_size = mesh.size()
        rank = mesh.get_local_rank()
        assert tensor.size()[dim] % (2 * world_size) == 0
        chunks = tensor.chunk(2 * world_size, dim=dim)
        i, j = rank, 2 * world_size - 1 - rank
        return torch.cat((chunks[i], chunks[j]), dim=dim)

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        world_size = mesh.size()
        # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
        gathered = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
        # Split the gathered tensor back into per-rank pieces, then into the two half-chunks each
        # rank held, so they can be reordered into the original (unsharded) sequence order.
        per_rank_tensors = gathered.chunk(world_size, dim=dim)
        sliced_tensors = [st for t in per_rank_tensors for st in t.chunk(2, dim=dim)]
        ordered_tensors = list(sliced_tensors)
        for i, t in enumerate(sliced_tensors):
            if i % 2 == 0:
                ordered_tensors[i // 2] = t
            else:
                ordered_tensors[world_size * 2 - (i // 2) - 1] = t
        return torch.cat(ordered_tensors, dim=dim)