# ghostv1 / apex / tests / L0 / run_optimizers / test_dist_adam.py
# Jagrut Thakare · v1 · 6d92c79 · 5.92 kB
import argparse
import random
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
class TestModel(torch.nn.Module):
    """Sequential stack of square ``torch.nn.Linear`` layers.

    The stack is exposed as ``self.linear`` and is configured from ``args``:
    ``args.layers`` layers, each mapping ``args.dim`` features to ``args.dim``
    features, with a bias term when ``args.bias`` is true.
    """

    def __init__(self, args):
        super().__init__()
        # Create the layers one by one, in order, then chain them.
        layers = []
        for _ in range(args.layers):
            layers.append(torch.nn.Linear(args.dim, args.dim, bias=args.bias))
        self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply every linear layer to ``x`` in sequence and return the result."""
        out = self.linear(x)
        return out
def setup(args):
    """Build the reference and distributed model/optimizer pairs.

    The reference path is a ``TestModel`` driven by ``FusedAdam``, initialized
    through amp (opt level O2, dynamic loss scale) and wrapped in DDP. The
    distributed path is a half-precision ``TestModel`` holding a copy of the
    same weights, driven by ``DistributedFusedAdam``.

    Args:
        args: Namespace providing ``dim``, ``layers`` and ``bias`` (model
            shape) plus ``n_gpu`` and ``dwu_group_size`` (distributed
            optimizer configuration).

    Returns:
        Tuple ``(ref_model, ref_opt, dist_model, dist_opt)``.
    """
    # Models: both live on the current CUDA device.
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Start the distributed model from the reference weights, then cast it
    # to half precision.
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)
    dist_model = dist_model.half()

    # Optimizers: identical hyperparameters; the distributed optimizer gets
    # its extra configuration on top.
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = {
        **ref_opt_args,
        'overlap_reductions': False,
        'process_group_size': args.n_gpu,
        'dwu_group_size': args.dwu_group_size,
        'dwu_num_blocks': 1,
        'dwu_num_chunks': 1,
    }
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    # amp initialization of the reference path only.
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    # DDP: bind to the device this process actually uses (the one the models
    # were moved to by .cuda()). The global rank is not a valid local device
    # index once the job spans more than one node.
    ref_model = DDP(ref_model, device_ids=[torch.cuda.current_device()])

    # Every rank starts from rank 0's parameters.
    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)

    torch.cuda.synchronize()
    torch.distributed.barrier()
    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')
    return ref_model, ref_opt, dist_model, dist_opt
def parse_args():
    """Parse the command-line options of the test script.

    Returns:
        argparse.Namespace with ``local_rank``, ``steps``, ``batch``, ``dim``,
        ``layers``, ``bias``, ``atol``, ``rtol`` and ``dwu_group_size``.
    """
    parser = argparse.ArgumentParser()
    # Accept both spellings: launchers from PyTorch 2.0 onwards pass the
    # dashed form. The destination stays ``local_rank`` either way.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    # A group size is a count, so parse it as an integer (it was a float).
    parser.add_argument('--dwu_group_size', type=int, default=1)
    return parser.parse_args()
def setup_env(args):
    """Select the CUDA device, join the NCCL process group and seed the RNGs.

    Stores the process rank in ``args.rank`` and the world size in
    ``args.n_gpu``, then seeds Python's ``random`` and torch with
    ``42 + rank`` so each rank uses its own seed.

    Args:
        args: Parsed options; ``args.local_rank`` selects the CUDA device.

    Returns:
        The same ``args`` object, with ``rank`` and ``n_gpu`` set.
    """
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    dist = torch.distributed
    args.rank = dist.get_rank()
    args.n_gpu = dist.get_world_size()

    # Per-rank seed, applied to both random number generators.
    seed = args.rank + 42
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    return args
def get_rank():
    """Return this process's rank as reported by ``torch.distributed``."""
    dist = torch.distributed
    return dist.get_rank()
def main():
    """Compare FusedAdam (amp + DDP) against DistributedFusedAdam step by step.

    For ``args.steps`` iterations, the same random half-precision input is fed
    through both models, the same random output gradient is back-propagated,
    and both optimizers are stepped. Inputs, outputs, input gradients and the
    updated parameters are compared with ``torch.allclose`` using the
    ``--atol`` / ``--rtol`` tolerances.

    Raises:
        AssertionError: if inputs, outputs or input gradients diverge.
        SystemExit: with status 1 if any parameter diverges after a step, so
            the failure is visible to whatever launched the script.
    """
    args = setup_env(parse_args())
    tol_args = {'atol': args.atol, 'rtol': args.rtol}
    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    # The rank does not change during the run; progress output is rank-0 only.
    is_rank0 = get_rank() == 0

    for step in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if is_rank0:
            print(f'[{step}] Checking input')
        assert torch.allclose(x_ref, x_dist, **tol_args)

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)

        if is_rank0:
            print(f'[{step}] Checking output')
        assert torch.allclose(y_ref, y_dist, **tol_args)

        # Same upstream gradient for both backward passes.
        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)

        if is_rank0:
            print(f'[{step}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert torch.allclose(x_ref.grad, x_dist.grad, **tol_args)

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if is_rank0:
            print(f'[{step}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()
        if is_rank0:
            print('Checking new weights')
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        # A separate index name keeps the step counter from being shadowed.
        param_pairs = zip(ref_model.parameters(), dist_model.parameters())
        for idx, (rp, dp) in enumerate(param_pairs):
            if torch.allclose(rp, dp, **tol_args):
                continue
            if is_rank0:
                print(f'Rank: {get_rank()}, Param: {idx}')
                print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                print(rp)
                print(dp)
                print(torch.abs(rp - dp) > tol_args['atol'])
            # A mismatch is a test failure: exit non-zero on every rank that
            # sees it, instead of reporting success with status 0.
            sys.exit(1)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()