# @package _global_
defaults:
  - /experiment/owt/gpt2m-flash.yaml
  - override /model/gpt2model: gpt2-large
  # TD [2022-08-03] Surprisingly, it's faster to use the ZeRO optimizer than just AdamW.
  # Still, fairscale is even faster and uses less memory.
  # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2?
  # However, fairscale has issues with saving checkpoints (either OOM or very
  # slow, since it goes through the CPU?). Fairscale says Pytorch ZeRO is the
  # upstream version of OSS:
  # https://github.com/facebookresearch/fairscale/issues/937
  # Pytorch ZeRO is also very slow for saving checkpoints due to
  # consolidate_state_dict(), but I've fixed it to save a separate checkpoint per GPU.
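  # (Added note: Pytorch's ZeroRedundancyOptimizer.consolidate_state_dict() gathers
  # every rank's optimizer shard onto a single rank before state_dict() can be
  # called, which is the slow part; writing each rank's shard to its own file
  # avoids that gather, though the sharded checkpoints are then presumably tied
  # to the same world size when resuming.)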
  - override /optimizer: adamw-zero
  # FusedAdam doesn't seem to speed things up here: time per global step
  # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam.
  # This could be because each GPU is only doing the optimizer step for 1 /
  # world_size of the parameters.
  # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO).
  # - override /optimizer: adamw-apex-zero

# Can enable mlp_checkpoint_lvl to fit batch_size 16 on an A100 40GB
# model:
#   config:
#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
#     mlp_checkpoint_lvl: 1

datamodule:
  # batch_size: 16
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}
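  # e.g. train.gpu_mem=16 resolves to 4, 24 or 32 to 8, 40 to 16, and 80 to 32
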
trainer:
  # strategy: null
  # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"}
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
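    # (Added note: find_unused_parameters=False skips DDP's per-iteration scan for
    # parameters that received no gradient, and gradient_as_bucket_view=True lets
    # gradients be views into the communication buckets instead of separate copies.)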
  # TD [2022-08-03] Deepspeed makes the ppl curve go wild
  # strategy: deepspeed_stage_1