Yunus Serhat Bıçakçı commited on
Commit
4ce6fd4
·
1 Parent(s): d24583c
Files changed (1) hide show
  1. pages/4_LLM.py +1 -1
pages/4_LLM.py CHANGED
@@ -39,7 +39,7 @@ def setup_model_parallel() -> Tuple[int, int]:
39
  local_rank = int(os.environ.get("LOCAL_RANK", -1))
40
  world_size = int(os.environ.get("WORLD_SIZE", -1))
41
 
42
- torch.distributed.init_process_group("nccl")
43
  initialize_model_parallel(world_size)
44
  torch.cuda.set_device(local_rank)
45
 
 
39
  local_rank = int(os.environ.get("LOCAL_RANK", -1))
40
  world_size = int(os.environ.get("WORLD_SIZE", -1))
41
 
42
+ torch.distributed.init_process_group("mpi")
43
  initialize_model_parallel(world_size)
44
  torch.cuda.set_device(local_rank)
45