Spaces:
Running
Running
File size: 5,654 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional
import torch
from finetrainers.trackers import DummyTracker, TrackerType, initialize_trackers
class BaseParallelBackend:
    r"""
    Base class that contains properties and methods that should be implemented by different parallel backends.

    Apart from ``__init__``, ``initialize_trackers`` and ``log``, every method and property defined here only
    raises ``NotImplementedError``; a concrete backend is expected to override all of them.
    """

    def __init__(self):
        # Set by `initialize_trackers`; stays None until that has been called.
        self.tracker = None

    def enable_determinism(self, seed: int) -> None:
        """Subclass hook: make runs reproducible for the given ``seed``."""
        raise NotImplementedError("Method `enable_determinism` must be implemented by subclass.")

    def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
        """Subclass hook: apply DDP and return the resulting module."""
        raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")

    def apply_fsdp2(self, *args, **kwargs) -> torch.nn.Module:
        """Subclass hook: apply FSDP2 and return the resulting module."""
        raise NotImplementedError("Method `apply_fsdp2` must be implemented by subclass.")

    def apply_context_parallel(self, *args, **kwargs) -> torch.nn.Module:
        """Subclass hook: apply context parallelism and return the resulting module."""
        raise NotImplementedError("Method `apply_context_parallel` must be implemented by subclass.")

    def prepare_model(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_model` must be implemented by subclass.")

    def prepare_dataset(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")

    def prepare_dataloader(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.")

    def prepare_optimizer(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.")

    # The return annotation is quoted so it is not evaluated at class-definition time:
    # `torch.distributed.DeviceMesh` is only exported when torch is built with distributed support,
    # and an unquoted reference would make importing this module fail on such builds.
    def get_mesh(self, name: Optional[str] = None) -> "torch.distributed.DeviceMesh":
        """Subclass hook: return the device mesh, optionally selected by ``name``."""
        raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")

    def get_checkpointer(self, *args, **kwargs) -> "BaseCheckpointer":
        """Subclass hook: build and return a checkpointer for this backend."""
        raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.")

    def initialize_trackers(
        self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
    ) -> "TrackerType":
        """Create this process's experiment tracker, store it on ``self.tracker`` and return it.

        On the main process the trackers named in ``trackers`` are created through the module-level
        ``initialize_trackers`` helper. Every other process gets a ``DummyTracker`` instead (presumably
        a no-op stand-in -- confirm in ``finetrainers.trackers``).

        Args:
            trackers: Names of the trackers to initialize.
            experiment_name: Experiment name forwarded to the tracker helper.
            config: Run configuration forwarded to the tracker helper.
            log_dir: Directory forwarded to the tracker helper.

        Returns:
            The tracker now stored on ``self.tracker``.
        """
        # Inside a method, `initialize_trackers` resolves to the module-level import, not to this method.
        if self.is_main_process:
            self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
        else:
            self.tracker = DummyTracker()
        return self.tracker

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Send ``metrics`` for ``step`` to ``self.tracker``, on the main process only.

        ``initialize_trackers`` must have been called first. Otherwise ``self.tracker`` is still None
        and the main process raises ``AttributeError``.
        """
        if self.is_main_process:
            self.tracker.log(metrics, step)

    def wait_for_everyone(self) -> None:
        raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.")

    # Because this body never yields, merely calling `main_process_first()` raises NotImplementedError.
    # An override must be a generator (it is wrapped by `contextmanager`) that yields exactly once.
    @contextmanager
    def main_process_first(self):
        raise NotImplementedError("Method `main_process_first` must be implemented by subclass.")

    def destroy(self) -> None:
        raise NotImplementedError("Method `destroy` must be implemented by subclass.")

    @property
    def world_size(self):
        raise NotImplementedError("Property `world_size` must be implemented by subclass.")

    @property
    def rank(self):
        raise NotImplementedError("Property `rank` must be implemented by subclass.")

    @property
    def local_rank(self):
        raise NotImplementedError("Property `local_rank` must be implemented by subclass.")

    @property
    def is_main_process(self):
        # Read by `initialize_trackers` and `log` to decide which process talks to the real tracker.
        raise NotImplementedError("Property `is_main_process` must be implemented by subclass.")

    @property
    def is_local_main_process(self):
        raise NotImplementedError("Property `is_local_main_process` must be implemented by subclass.")

    @property
    def device(self):
        raise NotImplementedError("Property `device` must be implemented by subclass.")

    @property
    def pipeline_parallel_enabled(self):
        raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.")

    @property
    def data_parallel_enabled(self):
        raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.")

    @property
    def data_replication_enabled(self):
        raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.")

    @property
    def data_sharding_enabled(self):
        raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.")

    @property
    def context_parallel_enabled(self):
        raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.")

    @property
    def tensor_parallel_enabled(self):
        raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")
class BaseCheckpointer:
    r"""
    Base class for checkpointers created by a parallel backend.

    A concrete checkpointer receives the training objects in ``__init__`` and implements ``save`` and
    ``load`` for them. The base implementations of all three methods only raise ``NotImplementedError``.

    Constructor arguments (descriptions inferred from the names -- confirm against the concrete subclass):
        dataloader: Training dataloader whose state is checkpointed.
        model_parts: Model modules (possibly several parts of one model) to checkpoint.
        optimizers: Optimizer object(s); the type is left to the subclass.
        schedulers: LR scheduler object(s); the type is left to the subclass.
        states: Additional named state to include in checkpoints.
        checkpointing_steps: Presumably the step interval between checkpoints.
        checkpointing_limit: Presumably the maximum number of checkpoints to keep.
        output_dir: Directory checkpoints are written under.
        enable: Whether checkpointing is enabled (defaults to True).
        _callback_fn: Optional hook mapping a state dict to a state dict (per its annotation); None disables it.
        _prefix: Name prefix for checkpoints (defaults to ``"finetrainers_step"``).
    """

    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model_parts: List[torch.nn.Module],
        optimizers: Any,
        schedulers: Any,
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: int,
        output_dir: str,
        enable: bool = True,
        # Defaults to None, so the annotation must say Optional rather than rely on implicit Optional.
        _callback_fn: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        _prefix: str = "finetrainers_step",
        *args,
        **kwargs,
    ) -> None:
        raise NotImplementedError("Method `__init__` must be implemented by subclass.")

    def save(self, step: int, force: bool, *, _device: Optional[torch.device] = None, _is_main_process: bool) -> str:
        """Subclass hook: write a checkpoint for ``step`` and return a string (presumably its path).

        ``force`` presumably bypasses the ``checkpointing_steps`` schedule -- confirm in subclass.
        ``_is_main_process`` is keyword-only and has no default, so callers must always pass it.
        """
        raise NotImplementedError("Method `save` must be implemented by subclass.")

    def load(self, step: int = -1) -> bool:
        """Subclass hook: restore the checkpoint for ``step`` and report success as a bool.

        The ``-1`` default presumably selects the latest checkpoint -- confirm in subclass.
        """
        raise NotImplementedError("Method `load` must be implemented by subclass.")
|