import os
import argparse

import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import TorchSyncBatchNorm

from core.memfof_lit import MEMFOFLit, DataModule
from config.parser import parse_args


def detect_cluster(args: argparse.Namespace) -> argparse.Namespace:
    """Infer the distributed layout (devices per node, number of nodes) from the environment.

    SLURM variables (set under `srun`) take priority; otherwise the variables set by
    `torchrun` are used. If neither is present, the values from the config are kept.
    """
    if all(env in os.environ for env in ("SLURM_NTASKS_PER_NODE", "SLURM_JOB_NUM_NODES")):
        args.devices = int(os.environ["SLURM_NTASKS_PER_NODE"])
        args.num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
    elif all(env in os.environ for env in ("WORLD_SIZE", "LOCAL_WORLD_SIZE")):
        args.devices = int(os.environ["LOCAL_WORLD_SIZE"])
        args.num_nodes = int(os.environ["WORLD_SIZE"]) // args.devices
    return args


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, help="experiment config file name", required=True)
    args = parse_args(parser)
    args = detect_cluster(args)

    if args.effective_batch_size % (args.num_nodes * args.devices) != 0:
        raise ValueError(
            f"Requested effective_batch_size={args.effective_batch_size} cannot be split evenly across {args.num_nodes} node(s) with {args.devices} device(s) each."
        )

    # Per-device batch size; divisibility was checked above, so integer division is exact.
    args.batch_size = args.effective_batch_size // (args.num_nodes * args.devices)

    monitor = LearningRateMonitor()
    # If no metric is monitored, save a single checkpoint at the final training step instead.
    checkpoint = ModelCheckpoint(
        dirpath="ckpts",
        filename=args.name,
        monitor=args.monitor,
        every_n_train_steps=args.num_steps if args.monitor is None else None,
    )

    wandb_logger = WandbLogger(
        project="MEMFOF",
        config=vars(args),
        log_model=True,
        checkpoint_name=args.name,
    )

    plugins = [
        # Convert BatchNorm layers to their synchronized counterparts for DDP training.
        TorchSyncBatchNorm(),
    ]

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=args.devices,
        strategy="ddp",
        num_nodes=args.num_nodes,
        logger=wandb_logger,
        gradient_clip_val=args.clip,
        precision="bf16-mixed",
        max_steps=args.num_steps,
        # Validate every `val_steps` training steps instead of once per epoch.
        check_val_every_n_epoch=None,
        val_check_interval=args.val_steps,
        callbacks=[monitor, checkpoint],
        plugins=plugins,
    )

    model = MEMFOFLit(args)
    datamodule = DataModule(args)
    trainer.fit(model, datamodule=datamodule)
    wandb.finish()
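
# --- Example launch commands (illustrative sketch; the script/config names are assumptions) ---
#
# Single node with 4 GPUs via torchrun (detect_cluster reads WORLD_SIZE / LOCAL_WORLD_SIZE):
#   torchrun --nproc_per_node=4 train.py --cfg <experiment_config>
#
# SLURM cluster, one task per GPU (detect_cluster reads SLURM_NTASKS_PER_NODE / SLURM_JOB_NUM_NODES):
#   srun python train.py --cfg <experiment_config>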