from typing import Any

import ignite
import ignite.distributed as idist
import torch
from ignite.engine import Engine, Events, EventsList
from omegaconf import OmegaConf

# TODO: move to utils or similar
def event_list_from_config(config) -> EventsList:
    """Build an ``EventsList`` from the config.

    An ``int`` is shorthand for "every N epochs, plus once on completion";
    otherwise ``config`` is an iterable of ``{"type": ..., "args": ...}``
    entries naming members of ``ignite.engine.Events``.
    """
    events = EventsList()
    if isinstance(config, int):
        events = events | Events.EPOCH_COMPLETED(every=config) | Events.COMPLETED
    else:
        for event in config:
            # guard with .get() so entries may omit "args" entirely
            if event.get("args"):
                events = events | Events[event["type"]](**event["args"])
            else:
                events = events | Events[event["type"]]
    return events
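
# Usage sketch for event_list_from_config (the dict schema below is an
# assumption inferred from the parsing above, not a documented format):
#
#     event_list_from_config(5)
#     # -> Events.EPOCH_COMPLETED(every=5) | Events.COMPLETED
#
#     event_list_from_config([
#         {"type": "ITERATION_COMPLETED", "args": {"every": 100}},
#         {"type": "COMPLETED", "args": None},
#     ])
#     # -> Events.ITERATION_COMPLETED(every=100) | Events.COMPLETED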


def global_step_fn(trainer: Engine, config: dict[str, Any]):
    """Return a ``global_step_transform`` callable resolved from the config."""
    step_type = config.get("type")
    match step_type:
        case "trainer epoch":
            return lambda engine, event_name: trainer.state.epoch
        case "trainer iteration":
            return lambda engine, event_name: trainer.state.iteration
        case _:
            # report the captured value: indexing config["type"] here would
            # raise a KeyError when the key is missing instead of this error
            raise ValueError(f"Unknown global step type: {step_type}")
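
# Usage sketch for global_step_fn: wiring an ignite logger so that validation
# metrics are plotted against the trainer's epoch (evaluator and tb_logger
# are hypothetical; attach_output_handler is ignite's standard logger API):
#
#     tb_logger.attach_output_handler(
#         evaluator,
#         event_name=Events.COMPLETED,
#         tag="validation",
#         metric_names="all",
#         global_step_transform=global_step_fn(trainer, {"type": "trainer epoch"}),
#     )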

# Alternative global-step transforms, kept for reference. They previously sat
# (unreachable) after the match statement above; note that ``every`` was never
# defined anywhere in this module:
#
# trainer iteration
# gst = lambda engine, event_name: trainer.state.iteration
#
# iteration within the current epoch
# gst_it_epoch = (
#     lambda engine, event_name: (trainer.state.epoch - 1)
#     * engine.state.epoch_length
#     + engine.state.iteration
#     - 1
# )
#
# gst_it_iters = (
#     lambda engine, event_name: (
#         (
#             (trainer.state.epoch - 1) * trainer.state.epoch_length
#             + trainer.state.iteration
#         )
#         // every
#     )
#     * engine.state.epoch_length
#     + engine.state.iteration
#     - 1
# )
#
# gst_ep_iters = lambda engine, event_name: (
#     (
#         (trainer.state.epoch - 1) * trainer.state.epoch_length
#         + trainer.state.iteration
#     )
#     // every
# )


def log_basic_info(logger, config):
    logger.info(f"Run {config['name']}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"Ignite version: {ignite.__version__}")
    if torch.cuda.is_available():
        # explicitly import cudnn, as torch.backends.cudnn
        # cannot be pickled with hvd spawning procs
        from torch.backends import cudnn

        logger.info(f"GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"CUDNN version: {cudnn.version()}")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info(f"\tbackend: {idist.backend()}")
        logger.info(f"\tworld size: {idist.get_world_size()}")
        logger.info("\n")

    logger.info("\n")
    logger.info(f"Configuration: \n{OmegaConf.to_yaml(config)}")