import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
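
# Numerically compares apex's FusedAdam (reference path: amp O2 + DistributedDataParallel)
# against DistributedFusedAdam (fp16 model, reductions driven by the optimizer) on a small
# stack of Linear layers, checking inputs, outputs, gradients and updated weights each step.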


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(
            *[torch.nn.Linear(args.dim, args.dim, bias=args.bias) for _ in range(args.layers)]
        )

    def forward(self, x):
        return self.linear(x)


def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    # The distributed optimizer operates on an fp16 model
    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update({'overlap_reductions': False})
    dist_opt_args.update({'process_group_size': args.n_gpu})
    dist_opt_args.update({'dwu_group_size': args.dwu_group_size})
    dist_opt_args.update({'dwu_num_blocks': 1})
    dist_opt_args.update({'dwu_num_chunks': 1})
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])

    # Make sure all ranks start from identical weights
    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=int, default=1)
    args = parser.parse_args()
    return args


def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()

    seed = 42 + get_rank()
    random.seed(seed)
    torch.manual_seed(seed)

    return args


def get_rank():
    return torch.distributed.get_rank()
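
# Note: each rank seeds its RNGs with a different value (42 + rank), so every rank
# generates its own random inputs, as in real data-parallel training. Both paths
# then reduce gradients across ranks (DDP for the reference model,
# complete_reductions() for the distributed optimizer) before stepping.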


def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = {'atol': args.atol, 'rtol': args.rtol}
    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training,
    # e.g. registering allreduce_hook, so that gradients are copied/reduced
    # when necessary
    dist_opt._init_everything()

    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)
        if get_rank() == 0:
            print(f'[{i}] Checking input')
            #print("x_ref:", x_ref.flatten()[:10])
            #print("x_dist:", x_dist.flatten()[:10])
        assert torch.allclose(x_ref, x_dist, **tol_args)

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)
        if get_rank() == 0:
            print(f'[{i}] Checking output')
            #print("y_ref:", y_ref.flatten()[:10])
            #print("y_dist:", y_dist.flatten()[:10])
        assert torch.allclose(y_ref, y_dist, **tol_args)

        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert torch.allclose(x_ref.grad, x_dist.grad, **tol_args)

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()
        torch.cuda.synchronize()
        torch.distributed.barrier()

        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        for p_i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {p_i}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)
                    print(torch.abs(rp - dp) > tol_args['atol'])
                sys.exit(0)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


if __name__ == "__main__":
    main()
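
# Example launch (a sketch: the script filename is a placeholder and the GPU count
# is arbitrary). The env:// init method and the --local_rank argument assume a
# torch.distributed.launch-style launcher:
#   python -m torch.distributed.launch --nproc_per_node=2 test_dist_adam.py --steps 20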