import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: if isinstance(x, torch.distributed.tensor.DTensor): # functional collectives do not support DTensor inputs x = x.full_tensor() assert x.numel() == 1 # required by `.item()` return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh) def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)