import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor


def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    """Reduce a scalar tensor across the ranks of `mesh` and return it as a Python float."""
    if isinstance(x, torch.distributed.tensor.DTensor):
        # functional collectives do not support DTensor inputs
        x = x.full_tensor()
    assert x.numel() == 1  # required by `.item()`
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    """Maximum of a scalar tensor across all ranks of `mesh`."""
    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)


def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    """Average of a scalar tensor across all ranks of `mesh`."""
    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
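

# A minimal usage sketch, assuming the script is launched with torchrun on CUDA
# devices and a PyTorch version that provides `init_device_mesh`; the loss value
# and the 1-D mesh shape below are illustrative, not part of the file above.
if __name__ == "__main__":
    import os

    from torch.distributed.device_mesh import init_device_mesh

    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

    # Each rank contributes its local scalar; the helpers return a Python float
    # that is identical on every rank after the all-reduce.
    local_loss = torch.tensor(0.123, device="cuda")
    print(f"mean loss: {dist_mean(local_loss, mesh)}, max loss: {dist_max(local_loss, mesh)}")

    torch.distributed.destroy_process_group()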