Spaces:
Running
on
Zero
Running
on
Zero
from typing import Any | |
import ignite | |
import ignite.distributed as idist | |
from ignite.engine import Engine, Events, EventsList | |
import torch | |
from omegaconf import OmegaConf | |
# TODO: move to utils or similar | |
def event_list_from_config(config) -> EventsList: | |
events = EventsList() | |
if isinstance(config, int): | |
events = events | Events.EPOCH_COMPLETED(every=config) | Events.COMPLETED | |
else: | |
for event in config: | |
if event["args"]: | |
events = events | Events[event["type"]](**event["args"]) | |
else: | |
events = events | Events[event["type"]] | |
return events | |
def global_step_fn(trainer: Engine, config: dict[str, Any]): | |
match config.get("type", None): | |
case "trainer epoch": | |
return lambda engine, event_name: trainer.state.epoch | |
case "trainer iteration": | |
return lambda engine, event_name: trainer.state.iteration | |
case _: | |
raise ValueError(f"Unknown global step type: {config['type']}") | |
# trainer iteration | |
gst = lambda engine, event_name: trainer.state.iteration | |
# # iteration per epoch | |
# gst_it_epoch = ( | |
# lambda engine, event_name: (trainer.state.epoch - 1) | |
# * engine.state.epoch_length | |
# + engine.state.iteration | |
# - 1 | |
# ) | |
# gst_it_iters = ( | |
# lambda engine, event_name: ( | |
# ( | |
# (trainer.state.epoch - 1) * trainer.state.epoch_length | |
# + trainer.state.iteration | |
# ) | |
# // every | |
# ) | |
# * engine.state.epoch_length | |
# + engine.state.iteration | |
# - 1 | |
# ) | |
# gst_ep_iters = lambda engine, event_name: ( | |
# ( | |
# (trainer.state.epoch - 1) * trainer.state.epoch_length | |
# + trainer.state.iteration | |
# ) | |
# // every | |
# ) | |
def log_basic_info(logger, config): | |
logger.info(f"Run {config['name']}") | |
logger.info(f"PyTorch version: {torch.__version__}") | |
logger.info(f"Ignite version: {ignite.__version__}") | |
if torch.cuda.is_available(): | |
# explicitly import cudnn as | |
# torch.backends.cudnn can not be pickled with hvd spawning procs | |
from torch.backends import cudnn | |
logger.info(f"GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}") | |
logger.info(f"CUDA version: {torch.version.cuda}") | |
logger.info(f"CUDNN version: {cudnn.version()}") | |
if idist.get_world_size() > 1: | |
logger.info("\nDistributed setting:") | |
logger.info(f"\tbackend: {idist.backend()}") | |
logger.info(f"\tworld size: {idist.get_world_size()}") | |
logger.info("\n") | |
logger.info("\n") | |
logger.info(f"Configuration: \n{OmegaConf.to_yaml(config)}") | |