# Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html
# We divide by world_size first before converting to fp16, so it's safer.
from typing import Any, Callable

import torch
import torch.distributed as dist


def fp16_compress_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    DDP communication hook that allreduces gradients in half precision.

    The bucket's gradient buffer is first divided by the size of the process
    group, and the quotient is written directly into a freshly allocated
    ``torch.float16`` tensor (division happens *before* the down-cast).
    That ``float16`` tensor is then allreduced asynchronously. Once the
    allreduce completes, a chained callback copies the reduced values back
    into the bucket's own buffer, which restores the buffer's original data
    type (such as ``float32``).

    Args:
        process_group: Group to communicate over. When ``None``,
            ``dist.group.WORLD`` is used instead.
        bucket: Gradient bucket whose ``buffer()`` holds the gradients to
            reduce. The buffer is overwritten in place with the result.

    Returns:
        A future that resolves to the bucket's buffer after the reduced
        values have been copied back into it.

    Example::
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    comm_group = dist.group.WORLD if process_group is None else process_group
    num_ranks = comm_group.size()

    # Divide first, then down-cast: supplying a float16 ``out`` tensor lets
    # torch.div fuse the division with the dtype conversion.
    grads = bucket.buffer()
    half_grads = torch.empty_like(grads, dtype=torch.float16)
    torch.div(grads, num_ranks, out=half_grads)

    pending = dist.all_reduce(half_grads, group=comm_group, async_op=True)
    reduced_fut = pending.get_future()

    def _copy_back(done):
        # Write the reduced values straight into the existing bucket buffer
        # instead of allocating a new tensor, to keep peak memory down.
        # See: https://github.com/pytorch/pytorch/issues/45968
        target = bucket.buffer()
        target.copy_(done.value()[0])
        return target

    # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case
    # resend with fp32?
    return reduced_fut.then(_copy_back)