# ghostv1/apex/tests/L0/run_optimizers/test_dist_adam.py
# Jagrut Thakare
# v1
# 6d92c79
import argparse
import random
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
class TestModel(torch.nn.Module):
    """Stack of equally sized linear layers used as the model under test.

    Args:
        args: Namespace providing ``dim`` (input and output width of every
            layer), ``layers`` (number of stacked layers) and ``bias``
            (whether each layer has a bias term).
    """

    def __init__(self, args):
        super().__init__()
        # Every layer is dim x dim, so any number of them can be chained.
        self.linear = torch.nn.Sequential(
            *(
                torch.nn.Linear(args.dim, args.dim, bias=args.bias)
                for _ in range(args.layers)
            )
        )

    def forward(self, x):
        """Apply the stacked linear layers to ``x`` and return the result."""
        return self.linear(x)
def setup(args):
    """Build the reference and distributed model/optimizer pairs.

    The reference pair is a ``TestModel`` driven by ``FusedAdam``, passed
    through ``amp.initialize`` (opt_level ``'O2'``, dynamic loss scale) and
    wrapped in ``DistributedDataParallel``. The distributed pair is a
    ``TestModel`` converted with ``.half()`` and driven by
    ``DistributedFusedAdam``. Both models start from the same weights, and
    the parameters are broadcast from rank 0 before returning.

    Args:
        args: Namespace with the model options read by ``TestModel`` plus
            ``n_gpu`` and ``dwu_group_size``.

    Returns:
        Tuple ``(ref_model, ref_opt, dist_model, dist_opt)``.
    """
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights: copy the reference initialization into the other model
    # before it is converted to half precision.
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters for both optimizers
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = {
        **ref_opt_args,
        'overlap_reductions': False,
        'process_group_size': args.n_gpu,
        'dwu_group_size': args.dwu_group_size,
        'dwu_num_blocks': 1,
        'dwu_num_chunks': 1,
    }
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    # device_ids must name the local CUDA device that holds the module.
    # The global rank only coincides with it on a single node, so use the
    # current device, which is where .cuda() placed the model above.
    ref_model = DDP(ref_model, device_ids=[torch.cuda.current_device()])
    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt
def parse_args():
    """Parse the command-line options of this test script.

    Returns:
        argparse.Namespace with ``local_rank``, ``steps``, ``batch``,
        ``dim``, ``layers``, ``bias``, ``atol``, ``rtol`` and
        ``dwu_group_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='CUDA device index selected for this process')
    parser.add_argument('--steps', type=int, default=20,
                        help='number of training steps to compare')
    parser.add_argument('--batch', type=int, default=32,
                        help='number of rows in each random input')
    parser.add_argument('--dim', type=int, default=4,
                        help='input/output width of every linear layer')
    parser.add_argument('--layers', type=int, default=2,
                        help='number of stacked linear layers')
    parser.add_argument('--bias', action='store_true',
                        help='give every linear layer a bias term')
    parser.add_argument('--atol', type=float, default=1e-3,
                        help='absolute tolerance for torch.allclose')
    parser.add_argument('--rtol', type=float, default=1,
                        help='relative tolerance for torch.allclose')
    # A group size is a count: parse it as int so a value given on the
    # command line is not handed to the optimizer as a float such as 2.0.
    parser.add_argument('--dwu_group_size', type=int, default=1,
                        help='dwu_group_size passed to DistributedFusedAdam')
    return parser.parse_args()
def setup_env(args):
    """Select the CUDA device, join the process group and seed the RNGs.

    Args:
        args: Parsed arguments. ``local_rank`` is read; ``rank`` and
            ``n_gpu`` are written, and ``local_rank`` is overwritten when
            it had to be resolved from the environment.

    Returns:
        The same ``args`` object, updated in place.
    """
    import os  # local import: only needed for the LOCAL_RANK fallback

    local_rank = args.local_rank
    if local_rank < 0:
        # torchrun exports LOCAL_RANK instead of passing --local_rank.
        # Without this fallback set_device(-1) is a no-op and every
        # process would stay on the default device.
        local_rank = int(os.environ.get('LOCAL_RANK', local_rank))
        args.local_rank = local_rank
    torch.cuda.set_device(local_rank)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()

    # Rank-dependent seed, so the random streams differ between processes.
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args
def get_rank():
    """Return the rank reported by ``torch.distributed`` for this process."""
    rank = torch.distributed.get_rank()
    return rank
def main():
    """Step both optimizers in lockstep and compare them after every step.

    Each step feeds the same random half-precision input to the reference
    and the distributed model, checks inputs, outputs and input gradients
    with ``torch.allclose``, steps both optimizers and then compares the
    updated parameters.

    Raises:
        AssertionError: if inputs, outputs or input gradients differ beyond
            the configured tolerances.
        SystemExit: with status 1 if the parameters differ after a step.
    """
    args = setup_env(parse_args())
    tol_args = {'atol': args.atol, 'rtol': args.rtol}

    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)
    is_rank0 = get_rank() == 0

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for step in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if is_rank0:
            print(f'[{step}] Checking input')
        assert torch.allclose(x_ref, x_dist, **tol_args), \
            f'[{step}] inputs differ'

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)

        if is_rank0:
            print(f'[{step}] Checking output')
        assert torch.allclose(y_ref, y_dist, **tol_args), \
            f'[{step}] outputs differ'

        # Same upstream gradient for both models.
        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)

        if is_rank0:
            print(f'[{step}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert torch.allclose(x_ref.grad, x_dist.grad, **tol_args), \
            f'[{step}] input gradients differ'

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if is_rank0:
            print(f'[{step}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()
        print('Checking new weights')
        if is_rank0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        # A separate index name keeps the step counter from being shadowed.
        param_pairs = zip(ref_model.parameters(), dist_model.parameters())
        for param_idx, (rp, dp) in enumerate(param_pairs):
            if torch.allclose(rp, dp, **tol_args):
                continue
            if is_rank0:
                print(f'Rank: {get_rank()}, Param: {param_idx}')
                print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                print(rp)
                print(dp)
                print(torch.abs(rp - dp) > tol_args['atol'])
            # A mismatch is a test failure: exit non-zero, and do so on every
            # rank that sees it so no process keeps waiting on a collective.
            sys.exit(1)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None
# Run the comparison only when this file is executed as a script.
if __name__ == '__main__':
    main()