from typing import Any

import ignite
import ignite.distributed as idist
from ignite.engine import Engine, Events, EventsList
import torch
from omegaconf import OmegaConf


# TODO: move to utils or similar
def event_list_from_config(config) -> EventsList:
    """Build an ignite ``EventsList`` from ``config``.

    ``config`` is either an int N (fire every N epochs and on completion) or an
    iterable of ``{"type": <Events member name>, "args": <kwargs or None>}`` entries.
    """
    events = EventsList()
    if isinstance(config, int):
        events = events | Events.EPOCH_COMPLETED(every=config) | Events.COMPLETED
    else:
        for event in config:
            if event.get("args"):
                # e.g. {"type": "EPOCH_COMPLETED", "args": {"every": 5}}
                events = events | Events[event["type"]](**event["args"])
            else:
                events = events | Events[event["type"]]

    return events
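
# A minimal usage sketch (the config values below are illustrative, not taken
# from this repo's configs):
#
#     trainer = Engine(lambda engine, batch: None)
#     eval_events = event_list_from_config(
#         [
#             {"type": "EPOCH_COMPLETED", "args": {"every": 5}},
#             {"type": "COMPLETED", "args": None},
#         ]
#     )
#     trainer.add_event_handler(eval_events, lambda engine: print("evaluating"))
#
# Passing an int instead, e.g. event_list_from_config(5), is equivalent to
# Events.EPOCH_COMPLETED(every=5) | Events.COMPLETED.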


def global_step_fn(trainer: Engine, config: dict[str, Any]):
    """Return a ``global_step_transform``-style callable ``(engine, event_name) -> int``
    reporting the trainer's epoch or iteration, as selected by ``config["type"]``."""
    match config.get("type"):
        case "trainer epoch":
            return lambda engine, event_name: trainer.state.epoch
        case "trainer iteration":
            return lambda engine, event_name: trainer.state.iteration
        case _:
            raise ValueError(f"Unknown global step type: {config.get('type')}")

    # # iteration per epoch
    # gst_it_epoch = (
    #     lambda engine, event_name: (trainer.state.epoch - 1)
    #     * engine.state.epoch_length
    #     + engine.state.iteration
    #     - 1
    # )
    # gst_it_iters = (
    #     lambda engine, event_name: (
    #         (
    #             (trainer.state.epoch - 1) * trainer.state.epoch_length
    #             + trainer.state.iteration
    #         )
    #         // every
    #     )
    #     * engine.state.epoch_length
    #     + engine.state.iteration
    #     - 1
    # )
    # gst_ep_iters = lambda engine, event_name: (
    #     (
    #         (trainer.state.epoch - 1) * trainer.state.epoch_length
    #         + trainer.state.iteration
    #     )
    #     // every
    # )
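
# Sketch of how the returned callable is typically consumed, e.g. as the
# global_step_transform of an ignite TensorBoard output handler (evaluator,
# trainer and the log_dir below are placeholders):
#
#     from ignite.handlers.tensorboard_logger import TensorboardLogger
#
#     tb_logger = TensorboardLogger(log_dir="runs/example")
#     tb_logger.attach_output_handler(
#         evaluator,
#         event_name=Events.COMPLETED,
#         tag="validation",
#         metric_names="all",
#         global_step_transform=global_step_fn(trainer, {"type": "trainer epoch"}),
#     )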


def log_basic_info(logger, config):
    """Log the run name, library/CUDA versions, distributed setup and the resolved config."""
    logger.info(f"Run {config['name']}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"Ignite version: {ignite.__version__}")
    if torch.cuda.is_available():
        # explicitly import cudnn as
        # torch.backends.cudnn can not be pickled with hvd spawning procs
        from torch.backends import cudnn

        logger.info(f"GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"CUDNN version: {cudnn.version()}")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info(f"\tbackend: {idist.backend()}")
        logger.info(f"\tworld size: {idist.get_world_size()}")
        logger.info("\n")

    logger.info("\n")
    logger.info(f"Configuration: \n{OmegaConf.to_yaml(config)}")