Spaces:
Running
Running
import torch | |
import torch.distributed._functional_collectives as funcol | |
import torch.distributed.tensor | |
def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: | |
if isinstance(x, torch.distributed.tensor.DTensor): | |
# functional collectives do not support DTensor inputs | |
x = x.full_tensor() | |
assert x.numel() == 1 # required by `.item()` | |
return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() | |
def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    """Return the maximum of ``x`` taken across the ranks of ``mesh``.

    Thin wrapper around ``dist_reduce`` using the c10d ``MAX`` reduction name;
    ``x`` must satisfy ``dist_reduce``'s input requirement of a single element.
    """
    max_op_name = torch.distributed.distributed_c10d.ReduceOp.MAX.name
    return dist_reduce(x, max_op_name, mesh=mesh)
def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    """Return the average of ``x`` taken across the ranks of ``mesh``.

    Thin wrapper around ``dist_reduce`` using the c10d ``AVG`` reduction name;
    ``x`` must satisfy ``dist_reduce``'s input requirement of a single element.
    """
    # NOTE(review): AVG may not be supported by every c10d backend (e.g. gloo) — confirm for CPU runs.
    avg_op_name = torch.distributed.distributed_c10d.ReduceOp.AVG.name
    return dist_reduce(x, avg_op_name, mesh=mesh)