import itertools
import os
import random
from datetime import timedelta
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import checkpoint as dist_checkpoint
from torch.distributed import fsdp
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

from surya.utils.schemas import TrainState


def init_dist(backend: str, rank: int, world_size: int):
    """Initialize the default process group with the given backend
    ("nccl" or "gloo"), reading the rendezvous info from the environment.
    """
    torch.distributed.init_process_group(
        backend,
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(minutes=60),
    )


def init_ddp(use_gpu: bool):
    """Initialize distributed training from torchrun-style environment
    variables and return ``(local_rank, rank)``.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    if use_gpu:
        assert (
            torch.cuda.is_available()
        ), "GPU requested but none was found in the system."
        # These variables are read when the NCCL process group is created,
        # so they must be set *before* init_process_group runs.
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = str(1)
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
        init_dist("nccl", rank, world_size)
        torch.cuda.set_device(local_rank)
        cudnn.benchmark = True
    else:
        init_dist("gloo", rank, world_size)
    return local_rank, rank
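

# Illustrative usage sketch, not part of the module's API: under a
# torchrun launch such as `torchrun --nproc_per_node=4 train.py`, the
# environment variables read above are already populated. The function
# name here is hypothetical.
def _example_init_ddp():
    local_rank, rank = init_ddp(use_gpu=torch.cuda.is_available())
    set_global_seed(rank)
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    print0(f"world_size={get_world_size()}, running on {device}")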


def set_global_seed(rank):
    """Seed Python, NumPy, and PyTorch RNGs, offset by rank so that each
    process draws a different random stream.
    """
    random.seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    torch.manual_seed(42 + rank)
    np.random.seed(42 + rank)


def is_dist_avail_and_initialized():
    """Return whether the default process group is ready to use."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """Return the world size, or 1 when running without distribution."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return the global rank, or 0 when running without distribution."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_model(model, save_dir):
    """Obtain sharded model parameters from the GPU, then save the model
    as a distributed checkpoint to the given directory. Saving a
    distributed checkpoint means that the checkpoint will be split into
    individual files, one for each process.
    """
    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
    with fsdp.FullyShardedDataParallel.state_dict_type(
        model,
        fsdp.StateDictType.SHARDED_STATE_DICT,
        state_dict_config,
    ):
        cp_state_dict = {"model": model.state_dict()}
        dist_checkpoint.save_state_dict(
            cp_state_dict,
            dist_checkpoint.FileSystemWriter(save_dir),
        )


def load_model(model, load_dir):
    """Set the given model's state dictionary in-place from the given
    distributed checkpoint directory.
    """
    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
    with fsdp.FullyShardedDataParallel.state_dict_type(
        model,
        fsdp.StateDictType.SHARDED_STATE_DICT,
        state_dict_config,
    ):
        cp_state_dict = {"model": model.state_dict()}
        dist_checkpoint.load_state_dict(
            cp_state_dict,
            dist_checkpoint.FileSystemReader(load_dir),
        )
        model.load_state_dict(cp_state_dict["model"])
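

# Illustrative round trip for the sharded save/load pair above. It
# assumes a process group is already initialized (e.g. via init_ddp);
# the tiny Linear module and the path are stand-ins for the example.
def _example_sharded_checkpoint():
    model = fsdp.FullyShardedDataParallel(nn.Linear(8, 8).cuda())
    save_model(model, "checkpoints/sharded")  # writes one shard per rank
    load_model(model, "checkpoints/sharded")  # restores shards in place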


def is_root_process():
    """Return whether this process is the root process."""
    return torch.distributed.get_rank() == 0


# The reason we define this is that `torch.distributed` does not
# implement it; for the global rank, there's
# `torch.distributed.get_rank()`.
def get_local_rank():
    """Return the local rank of this process, defaulting to 0 when the
    `LOCAL_RANK` environment variable is unset (single-process runs).
    """
    return int(os.getenv("LOCAL_RANK", "0"))


def print0(*args, **kwargs):
    """Print something only on the root process."""
    if (not dist.is_initialized()) or is_root_process():
        print(*args, **kwargs)


def save_model_singular(model, save_path, parallelism, *args, **kwargs):
    """Stream all model parameters to rank 0 on the CPU, then pass all
    other given arguments to `torch.save` to save the model, but only on
    the root process.
    """
    match parallelism:
        case "fsdp":
            save_policy = fsdp.FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with fsdp.FullyShardedDataParallel.state_dict_type(
                model,
                fsdp.StateDictType.FULL_STATE_DICT,
                save_policy,
            ):
                cpu_state = model.state_dict()
            # We do *not* want to write to the same location with multiple
            # processes at the same time.
            if is_main_process():
                os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
                torch.save(cpu_state, save_path, *args, **kwargs)
        case "ddp":
            if is_main_process():
                torch.save(model.module.state_dict(), save_path, *args, **kwargs)
            dist.barrier()
        case _:
            raise ValueError(
                f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
            )
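

# A minimal usage sketch for the full-state-dict save above; the file
# path and the assumption that `model` is FSDP-wrapped are illustrative:
def _example_save_singular(model):
    save_model_singular(model, "checkpoints/model_full.pt", parallelism="fsdp")
    # Rank 0 now holds a single unsharded file that a plain
    # torch.load("checkpoints/model_full.pt") call can read.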


def save_optim_singular(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    save_path: str,
    parallelism: str = "fsdp",
):
    """Gather the full optimizer state on rank 0 and save it there."""
    match parallelism:
        case "fsdp":
            optim_state_dict_config = fsdp.FullOptimStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            )
            with fsdp.FullyShardedDataParallel.state_dict_type(
                model,
                fsdp.StateDictType.FULL_STATE_DICT,
                optim_state_dict_config=optim_state_dict_config,
            ):
                optim_state_dict = fsdp.FullyShardedDataParallel.optim_state_dict(
                    model, optimizer
                )
            if is_main_process():
                os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
                checkpoint = {
                    "optimizer_state_dict": optim_state_dict,
                }
                torch.save(checkpoint, save_path)
        case "ddp":
            if is_main_process():
                optim_state_dict = optimizer.state_dict()
                os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
                torch.save(optim_state_dict, save_path)
            dist.barrier()
        case _:
            raise ValueError(
                f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
            )
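

# Sketch of pairing the two singular savers at a checkpoint step; the
# epoch-based file names and the function itself are assumptions made
# for illustration only:
def _example_checkpoint_step(model, optimizer, epoch):
    save_model_singular(model, f"checkpoints/model_{epoch}.pt", parallelism="fsdp")
    save_optim_singular(model, optimizer, f"checkpoints/optim_{epoch}.pt", parallelism="fsdp")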


def collect_optim_singular(
    model: nn.Module, optimizer: torch.optim.Optimizer, parallelism: str = "fsdp"
) -> dict:
    """Return the full optimizer state dict, gathered on rank 0. On all
    other ranks the returned dict is empty.
    """
    optim_state_dict = {}
    match parallelism:
        case "fsdp":
            optim_state_dict_config = fsdp.FullOptimStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            )
            with fsdp.FullyShardedDataParallel.state_dict_type(
                model,
                fsdp.StateDictType.FULL_STATE_DICT,
                optim_state_dict_config=optim_state_dict_config,
            ):
                optim_state_dict = fsdp.FullyShardedDataParallel.optim_state_dict(
                    model, optimizer
                )
        case "ddp":
            if is_main_process():
                optim_state_dict = optimizer.state_dict()
            dist.barrier()
        case _:
            raise ValueError(
                f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
            )
    return optim_state_dict


def save_state_singular(states: TrainState, save_path, *args, **kwargs):
    """Save the given training state with `torch.save`, passing along any
    extra arguments, but only on the root process.
    """
    if is_main_process():
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        torch.save(states, save_path, *args, **kwargs)
    dist.barrier()
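

# Hedged sketch: `TrainState` is treated as opaque here; whatever it
# holds, the round trip is a plain `torch.save`/`torch.load` pair. The
# path and function name are illustrative assumptions.
def _example_save_state(states: TrainState):
    save_state_singular(states, "checkpoints/train_state.pt")
    if is_main_process():
        # weights_only=False because TrainState is an arbitrary object
        restored = torch.load("checkpoints/train_state.pt", weights_only=False)
        return restored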


class StatefulDistributedSampler(DistributedSampler):
    """A `DistributedSampler` that can checkpoint and resume its position
    within an epoch via `state_dict`/`load_state_dict`.
    """

    _YIELDED = "yielded"

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.yielded = 0
        self.next_yielded = None

    def __iter__(self):
        self.yielded = 0
        if self.next_yielded is not None:
            # Resume mid-epoch: skip the indices that were already
            # consumed before the checkpoint was taken.
            self.yielded = self.next_yielded
            self.next_yielded = None
        it = super().__iter__()
        for idx in itertools.islice(it, self.yielded, None):
            self.yielded += 1
            yield idx

    def state_dict(self) -> Dict[str, Any]:
        return {self._YIELDED: self.yielded}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if self._YIELDED not in state_dict:
            raise ValueError("Invalid state_dict")
        if state_dict[self._YIELDED] < 0:
            raise ValueError("Cannot load state_dict with negative yielded value")
        self.next_yielded = state_dict[self._YIELDED]
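

# Resume sketch for the sampler above; the dataset, batch size, and
# function name are assumptions. As with the plain `DistributedSampler`,
# `set_epoch` must still be called so shuffling stays deterministic
# across ranks.
def _example_resume_sampler(dataset, sampler_state=None):
    sampler = StatefulDistributedSampler(dataset, shuffle=True, seed=0)
    if sampler_state is not None:
        sampler.load_state_dict(sampler_state)  # skip already-seen indices
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
    sampler.set_epoch(0)
    for batch in loader:
        ...  # train step; periodically persist sampler.state_dict()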